#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# @lint-avoid-python-3-compatibility-imports
#
# cachetop      Count cache kernel function calls per processes
#               For Linux, uses BCC, eBPF.
#
# USAGE: cachetop
# Taken from cachestat by Brendan Gregg
#
# Copyright (c) 2016-present, Facebook, Inc.
# Licensed under the Apache License, Version 2.0 (the "License")
#
# 13-Jul-2016   Emmanuel Bretelle first version

from __future__ import absolute_import
from __future__ import division
# Do not import unicode_literals until #623 is fixed
# from __future__ import unicode_literals
from __future__ import print_function

from bcc import BPF
from collections import defaultdict
from time import strftime
import sys
sys.path.append('./plugins/common/')
from db_modules import write2db
from datetime import datetime
from init_db import influx_client
from const import DatabaseType

import argparse
import curses
import pwd
import re
import signal
from time import sleep

# Column names of the per-process table. The order matches the tuples built
# by get_processes_stats(): (pid, uid, comm, access, misses, dirties,
# read-hit %, write-hit %).
FIELDS = ("PID", "UID", "CMD", "HITS", "MISSES",
          "DIRTIES", "READ_HIT%", "WRITE_HIT%")
DEFAULT_FIELD = "HITS"
# Position of DEFAULT_FIELD in FIELDS; used as the default sort key index.
DEFAULT_SORT_FIELD = FIELDS.index(DEFAULT_FIELD)

def signal_ignore(signal, frame):
    """Signal handler that does nothing except emit a newline.

    Meant to be installed with ``signal.signal(signal.SIGINT, signal_ignore)``
    while the tool is shutting down, so further Ctrl-C presses do not abort
    cleanup. ``signal`` and ``frame`` are the standard handler arguments and
    are not used.
    """
    print("")


def get_meminfo(path='/proc/meminfo'):
    """Parse a meminfo-style file into a dict for quick lookups.

    Each line is expected to look like ``"Cached:   123456 kB"``. The label
    before the first colon maps to the first whitespace-separated token after
    it, converted to int. Any unit suffix is dropped.

    Args:
        path: file to parse; defaults to the kernel's ``/proc/meminfo``.

    Returns:
        dict mapping field name (e.g. ``"Cached"``) to its integer value.

    Raises:
        OSError: if ``path`` cannot be opened.
        ValueError: if a value token is not an integer.
    """
    result = {}
    # ``with`` closes the file deterministically instead of leaving the
    # handle to the garbage collector.
    with open(path) as meminfo:
        for line in meminfo:
            parts = line.split(':')
            values = parts[1].split() if len(parts) > 1 else []
            if not values:
                # Blank or malformed line: skip it rather than IndexError.
                continue
            result[parts[0]] = int(values[0])
    return result


def get_processes_stats(
        bpf,
        sort_field=DEFAULT_SORT_FIELD,
        sort_reverse=False):
    '''
    Collect, and then reset, the per-process page-cache counters.

    Args:
        bpf: the loaded BPF object. Its "counts" table is read and then
            cleared, so successive calls report (approximately) the events
            seen since the previous call.
        sort_field: index into each result tuple (same order as FIELDS)
            used as the sort key.
        sort_reverse: sort in descending order when True.

    Returns:
        A list of tuples, one per (pid, uid, comm) seen:
            (pid: int, uid: str, comm: str, access, misses, dirties,
             read_hit_pct, write_hit_pct)
        where
            access  = mark_page_accessed + mark_buffer_dirty counts
            misses  = add_to_page_cache_lru + account_page_dirtied counts
            dirties = mark_buffer_dirty count
            read_hit_pct  = 100 * mark_page_accessed / (access + misses)
            write_hit_pct = 100 * add_to_page_cache_lru / (access + misses)
        Each percentage is 0 when its numerator is 0.
    '''
    counts = bpf.get_table("counts")

    # Group the raw (pid, uid, comm, ip) counters by process. Within a
    # process, the instruction pointer identifies which kprobe fired.
    stats = defaultdict(lambda: defaultdict(int))
    for k, v in counts.items():
        proc = "%d-%d-%s" % (k.pid, k.uid, k.comm.decode('utf-8', 'replace'))
        stats[proc][k.ip] = v.value

    stats_list = []
    for proc, count in sorted(stats.items(), key=lambda stat: stat[0]):
        mpa = 0   # mark_page_accessed
        mbd = 0   # mark_buffer_dirty
        apcl = 0  # add_to_page_cache_lru
        apd = 0   # account_page_dirtied

        for ip, value in count.items():
            # Resolve the symbol once per address (previously once per probe).
            sym = bpf.ksym(ip)
            if sym.startswith(b'mark_page_accessed'):
                mpa = max(0, value)
            if sym.startswith(b'mark_buffer_dirty'):
                mbd = max(0, value)
            if sym.startswith(b'add_to_page_cache_lru'):
                apcl = max(0, value)
            if sym.startswith(b'account_page_dirtied'):
                apd = max(0, value)

        # access = total cache accesses: reads (mpa) plus writes (mbd).
        # misses = pages added to the page-cache LRU plus pages dirtied.
        access = mpa + mbd
        misses = apcl + apd
        total = access + misses

        # Read/write hit % for the sample period. total >= numerator > 0
        # whenever the division runs, so it cannot divide by zero.
        rhits = 100 * (float(mpa) / total) if mpa > 0 else 0
        whits = 100 * (float(apcl) / total) if apcl > 0 else 0

        # comm may itself contain '-', hence maxsplit=2.
        _pid, uid, comm = proc.split('-', 2)
        stats_list.append(
            (int(_pid), uid, comm, access, misses, mbd, rhits, whits))

    stats_list.sort(key=lambda stat: stat[sort_field], reverse=sort_reverse)
    counts.clear()
    return stats_list

# Record for one per-process sample; instances are passed to write2db()
# together with data_struct.
class lmp_data(object):
    """One cachetop sample row, passed to write2db() with data_struct.

    The ten positional constructor arguments are stored, in order, as:
    time, glob, pid, uid, cmd, hits, misses, DIRTIES, READ_HIT, WRITE_HIT.
    These names match the tag/field names listed in data_struct.
    """

    def __init__(self, a, b, c, d, e, f, g, h, i, j):
        (self.time, self.glob, self.pid, self.uid, self.cmd,
         self.hits, self.misses, self.DIRTIES,
         self.READ_HIT, self.WRITE_HIT) = (a, b, c, d, e, f, g, h, i, j)

# Schema handed to write2db() (InfluxDB backend) for every lmp_data row:
# the measurement name, the lmp_data attribute(s) used as tags, and the
# lmp_data attributes written as fields. "time" is left empty here; how
# write2db interprets it is not visible in this file (TODO confirm).
data_struct = {
    "measurement": 'cachetop',
    "time": [],
    "tags": ['glob'],
    "fields": [
        'time', 'pid', 'uid', 'cmd', 'hits', 'misses',
        'DIRTIES', 'READ_HIT', 'WRITE_HIT',
    ],
}

# eBPF program source, compiled and loaded by BCC when BPF(text=bpf_text)
# is constructed below. do_count() is attached as a kprobe handler. Each
# call increments a counter in the "counts" hash, keyed by:
#   - the probed instruction pointer (PT_REGS_IP),
#   - the upper 32 bits of bpf_get_current_pid_tgid() (the process id),
#   - the uid,
#   - the 16-byte task comm.
# get_processes_stats() reads and then clears that table.
bpf_text = """

#include <uapi/linux/ptrace.h>
struct key_t {
    u64 ip;
    u32 pid;
    u32 uid;
    char comm[16];
};

BPF_HASH(counts, struct key_t);

int do_count(struct pt_regs *ctx) {
    struct key_t key = {};
    u64 pid = bpf_get_current_pid_tgid();
    u32 uid = bpf_get_current_uid_gid();

    key.ip = PT_REGS_IP(ctx);
    key.pid = pid >> 32;
    key.uid = uid;
    bpf_get_current_comm(&(key.comm), 16);

    counts.increment(key);
    return 0;
}

"""
def parse_arguments():
    """Parse the command line.

    Returns:
        argparse.Namespace whose ``interval`` attribute is the sampling
        period in seconds (optional positional, default 5).
    """
    parser = argparse.ArgumentParser(
        description='show Linux page cache hit/miss statistics including read '
                    'and write hit % per processes in a UI like top.'
    )
    parser.add_argument(
        'interval', type=int, default=5, nargs='?',
        help='Interval between probes.'
    )

    args = parser.parse_args()
    if args.interval <= 0:
        # sleep(0) would busy-loop and sleep(<0) raises ValueError.
        parser.error('interval must be a positive number of seconds')
    return args


# Parse options before attaching any kprobes, so that --help or bad
# arguments exit immediately and the sampling loop knows its interval.
args = parse_arguments()

b = BPF(text=bpf_text)
b.attach_kprobe(event="add_to_page_cache_lru", fn_name="do_count")
b.attach_kprobe(event="mark_page_accessed", fn_name="do_count")
b.attach_kprobe(event="account_page_dirtied", fn_name="do_count")
b.attach_kprobe(event="mark_buffer_dirty", fn_name="do_count")

exiting = 0
while 1:
    # Let the BPF counters accumulate for one period. Without this wait the
    # loop spins at full speed and floods the database with near-empty
    # samples.
    try:
        sleep(args.interval)
    except KeyboardInterrupt:
        exiting = 1
        # As cleanup can take many seconds, trap further Ctrl-C presses.
        signal.signal(signal.SIGINT, signal_ignore)

    # Reads and clears the per-process counters gathered since last time.
    process_stats = get_processes_stats(b)

    for stat in process_stats:
        uid = int(stat[1])
        try:
            username = pwd.getpwuid(uid)[0]
        except KeyError:
            # `pwd` throws a KeyError if the user cannot be found. This can
            # happen e.g. when the process is running in a cgroup that has
            # different users from the host.
            username = 'UNKNOWN({})'.format(uid)
        # Write every process, including those whose uid did not resolve
        # (the write used to sit inside the try and skip them).
        test_data = lmp_data(datetime.now().isoformat(), 'glob', stat[0],
                             username, stat[2], stat[3], stat[4], stat[5],
                             float(stat[6]), float(stat[7]))
        write2db(data_struct, test_data, influx_client,
                 DatabaseType.INFLUXDB.value)

    if exiting:
        print("Detaching...")
        break

